package com.study.crypto.basic.utils;

import org.bouncycastle.asn1.*;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.crypto.engines.SM2Engine;
import org.bouncycastle.crypto.engines.SM2Engine.Mode;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.crypto.params.ParametersWithRandom;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
import org.bouncycastle.jce.spec.ECParameterSpec;
import org.bouncycastle.math.ec.ECPoint;
import org.bouncycastle.pqc.math.linearalgebra.ByteUtils;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.BigIntegers;

import java.io.IOException;
import java.math.BigInteger;
import java.security.InvalidParameterException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;

/**
 * @author Songjin
 * @since 2021-03-01 1:16
 */
public final class SM2Utils {
    
    /** Number of elements in the DER cipher structure: x, y, hash (C3), ciphertext (C2). */
    private static final int CIPHER_FIELD_COUNT = 4;
    
    /** Error message thrown when an unsupported {@link Mode} is supplied. */
    private static final String INVALID_MODE_MESSAGE = "SM2Engine.Mode参数值不合法";
    
    private SM2Utils() {
    }
    
    /**
     * Encrypts {@code inData} with SM2 and returns the result as the DER structure
     * {@code SEQUENCE { x INTEGER, y INTEGER, hash OCTET STRING, ciphertext OCTET STRING }}.
     *
     * <p>The raw {@link SM2Engine} output is {@code C1 || C2 || C3} or {@code C1 || C3 || C2}
     * depending on {@code mode}. C1 is the ephemeral point in uncompressed form, C2 is the
     * encrypted data (same length as the plaintext) and C3 is the SM3 digest. The DER
     * element order is the same for both modes; {@code mode} only controls how the raw
     * engine output is split.
     *
     * @param publicKey SM2 public key; must be a {@link BCECPublicKey}
     * @param inData    plaintext
     * @param mode      cipher-text layout used by the SM2 engine
     * @return DER-encoded ciphertext
     * @throws InvalidCipherTextException if the engine fails to encrypt
     * @throws IOException                if DER encoding fails
     * @throws InvalidParameterException  if {@code mode} is not a supported layout
     */
    public static byte[] encrypt(PublicKey publicKey, byte[] inData, Mode mode) throws InvalidCipherTextException, IOException {
        BCECPublicKey bcecPublicKey = (BCECPublicKey) publicKey;
        ECParameterSpec ecParameterSpec = bcecPublicKey.getParameters();
        ECDomainParameters ecDomainParameters = toDomainParameters(ecParameterSpec);
        ECPublicKeyParameters ecPublicKeyParameters = new ECPublicKeyParameters(bcecPublicKey.getQ(), ecDomainParameters);
        SM3Digest sm3Digest = new SM3Digest();
        SM2Engine sm2Engine = new SM2Engine(sm3Digest, mode);
        sm2Engine.init(true, new ParametersWithRandom(ecPublicKeyParameters, new SecureRandom()));
        byte[] encrypted = sm2Engine.processBlock(inData, 0, inData.length);
        
        // C1 is an uncompressed point: 0x04 || X || Y, each coordinate as wide as the field.
        int c1Length = 1 + 2 * coordinateLength(ecParameterSpec);
        int c2Length = inData.length;
        int c3Length = sm3Digest.getDigestSize();
        byte[] c1 = Arrays.copyOfRange(encrypted, 0, c1Length);
        byte[] c2;
        byte[] c3;
        switch (mode) {
            case C1C2C3:
                c2 = Arrays.copyOfRange(encrypted, c1Length, c1Length + c2Length);
                c3 = Arrays.copyOfRange(encrypted, c1Length + c2Length, encrypted.length);
                break;
            case C1C3C2:
                c3 = Arrays.copyOfRange(encrypted, c1Length, c1Length + c3Length);
                c2 = Arrays.copyOfRange(encrypted, c1Length + c3Length, encrypted.length);
                break;
            default:
                throw new InvalidParameterException(INVALID_MODE_MESSAGE);
        }
        
        // Wrap C1 (as its x/y coordinates), C3 and C2 into the DER structure.
        ECPoint c1Point = ecDomainParameters.getCurve().decodePoint(c1);
        ASN1EncodableVector vector = new ASN1EncodableVector();
        vector.add(new ASN1Integer(c1Point.getXCoord().toBigInteger()));
        vector.add(new ASN1Integer(c1Point.getYCoord().toBigInteger()));
        vector.add(new DEROctetString(c3));
        vector.add(new DEROctetString(c2));
        return new DERSequence(vector).getEncoded(ASN1Encoding.DER);
    }
    
    /**
     * Unwraps a DER-encoded SM2 cipher {@code SEQUENCE} into its first four elements.
     *
     * <p>For input produced by {@link #encrypt}, the elements are, in order: the x
     * coordinate of C1, the y coordinate of C1, the hash C3 and the ciphertext C2.
     * Element types are not checked here.
     *
     * @param encrypted DER-encoded SM2 cipher structure
     * @return a vector holding the first four elements of the sequence
     * @throws IOException if {@code encrypted} is not valid ASN.1, is not a SEQUENCE,
     *                     or has fewer than four elements
     */
    public static ASN1EncodableVector decapsulateCipher(byte[] encrypted) throws IOException {
        ASN1Sequence sequence = toCipherSequence(ASN1Primitive.fromByteArray(encrypted));
        ASN1EncodableVector vector = new ASN1EncodableVector();
        for (int i = 0; i < CIPHER_FIELD_COUNT; i++) {
            vector.add(sequence.getObjectAt(i));
        }
        return vector;
    }
    
    /**
     * Decrypts a DER-encoded SM2 ciphertext produced by {@link #encrypt}.
     *
     * <p>The DER fields are reassembled into the raw {@code C1 || C2 || C3} or
     * {@code C1 || C3 || C2} layout expected by {@link SM2Engine}. {@code mode} must
     * therefore match the mode the engine should use to interpret that layout.
     *
     * @param privateKey SM2 private key; must be a {@link BCECPrivateKey}
     * @param ciphertext DER-encoded ciphertext
     * @param mode       cipher-text layout used by the SM2 engine
     * @return the recovered plaintext
     * @throws InvalidCipherTextException if the engine rejects the ciphertext
     * @throws IOException                if the DER structure is malformed
     * @throws InvalidParameterException  if {@code mode} is not a supported layout
     */
    public static byte[] decrypt(PrivateKey privateKey, byte[] ciphertext, Mode mode) throws InvalidCipherTextException, IOException {
        BCECPrivateKey bcecPrivateKey = (BCECPrivateKey) privateKey;
        ECParameterSpec ecParameterSpec = bcecPrivateKey.getParameters();
        int coordinateLength = coordinateLength(ecParameterSpec);
        
        ASN1Sequence sequence;
        try (ASN1InputStream inStr = new ASN1InputStream(ciphertext)) {
            sequence = toCipherSequence(inStr.readObject());
        }
        BigInteger x = ASN1Integer.getInstance(sequence.getObjectAt(0)).getValue();
        BigInteger y = ASN1Integer.getInstance(sequence.getObjectAt(1)).getValue();
        byte[] c3 = ASN1OctetString.getInstance(sequence.getObjectAt(2)).getOctets();
        byte[] c2 = ASN1OctetString.getInstance(sequence.getObjectAt(3)).getOctets();
        
        // Rebuild C1 as an uncompressed point 0x04 || X || Y. The coordinate width follows
        // the curve's field size rather than a fixed 32 bytes. asUnsignedByteArray(len, v)
        // left-pads short values and rejects values wider than the field.
        byte[] c1 = Arrays.concatenate(
                new byte[]{4},
                BigIntegers.asUnsignedByteArray(coordinateLength, x),
                BigIntegers.asUnsignedByteArray(coordinateLength, y));
        
        byte[] in;
        if (Mode.C1C2C3.equals(mode)) {
            in = Arrays.concatenate(c1, c2, c3);
        } else if (Mode.C1C3C2.equals(mode)) {
            in = Arrays.concatenate(c1, c3, c2);
        } else {
            throw new InvalidParameterException(INVALID_MODE_MESSAGE);
        }
        
        ECPrivateKeyParameters ecPrivateKeyParameters = new ECPrivateKeyParameters(bcecPrivateKey.getD(), toDomainParameters(ecParameterSpec));
        // Same digest (SM3) as encrypt uses.
        SM2Engine sm2Engine = new SM2Engine(new SM3Digest(), mode);
        sm2Engine.init(false, ecPrivateKeyParameters);
        return sm2Engine.processBlock(in, 0, in.length);
    }
    
    /**
     * Converts a JCE EC parameter spec into lightweight-API domain parameters.
     * The cofactor is carried over so that SM2Engine's cofactor check on C1 uses
     * the real value; it falls back to 1 when the spec does not provide one.
     */
    private static ECDomainParameters toDomainParameters(ECParameterSpec spec) {
        BigInteger cofactor = spec.getH() != null ? spec.getH() : BigInteger.ONE;
        return new ECDomainParameters(spec.getCurve(), spec.getG(), spec.getN(), cofactor);
    }
    
    /** Returns the byte length of one field element (one point coordinate) for the curve. */
    private static int coordinateLength(ECParameterSpec spec) {
        return (spec.getCurve().getFieldSize() + 7) / 8;
    }
    
    /**
     * Checks that a parsed primitive is a SEQUENCE with at least the four SM2 cipher
     * elements, and returns it.
     *
     * @throws IOException if the primitive is null, not a SEQUENCE, or too short
     */
    private static ASN1Sequence toCipherSequence(ASN1Primitive primitive) throws IOException {
        if (!(primitive instanceof ASN1Sequence)) {
            throw new IOException("SM2 cipher must be an ASN.1 SEQUENCE");
        }
        ASN1Sequence sequence = (ASN1Sequence) primitive;
        if (sequence.size() < CIPHER_FIELD_COUNT) {
            throw new IOException("SM2 cipher SEQUENCE must contain at least " + CIPHER_FIELD_COUNT
                    + " elements, found " + sequence.size());
        }
        return sequence;
    }
}
